"""
手势识别
"""
import os
import cv2
import time
import numpy as np

class HandPoseModel(object):
    """Detect hand keypoints with a Caffe pose model run through ``cv2.dnn``.

    ``predict`` returns one entry per network output channel (``numPoints``
    entries): an ``(x, y)`` pixel coordinate in the input image, or ``None``
    when that channel's peak confidence does not exceed ``threshold``.
    """

    def __init__(self, model_path, in_height=368, threshold=0.1):
        """Load the network and set inference parameters.

        Args:
            model_path: directory containing ``pose_deploy.prototxt`` and
                ``pose_iter_102000.caffemodel``.
            in_height: height in pixels of the network input blob; the blob
                width is derived from each image's aspect ratio.
            threshold: minimum peak confidence for a keypoint to be reported.
        """
        # Number of output channels read in predict(). point_pairs only uses
        # indices 0-20, so channel 21 is presumably the background map.
        # NOTE(review): confirm against the model's prototxt.
        self.numPoints = 22
        # Keypoint index pairs (skeleton edges); not used inside this class,
        # exposed for callers that want to draw the hand skeleton.
        self.point_pairs = [[0, 1], [1, 2], [2, 3], [3, 4],
                            [0, 5], [5, 6], [6, 7], [7, 8],
                            [0, 9], [9, 10], [10, 11], [11, 12],
                            [0, 13], [13, 14], [14, 15], [15, 16],
                            [0, 17], [17, 18], [18, 19], [19, 20]]

        self.inHeight = in_height
        self.threshold = threshold

        self.hand_net = self.get_hand_model(model_path)

    def get_hand_model(self, model_path):
        """Load the Caffe network from ``model_path``.

        Expects the files ``pose_deploy.prototxt`` (architecture) and
        ``pose_iter_102000.caffemodel`` (weights) inside ``model_path``.
        """
        prototxt = os.path.join(model_path, "pose_deploy.prototxt")
        caffemodel = os.path.join(model_path, "pose_iter_102000.caffemodel")
        return cv2.dnn.readNetFromCaffe(prototxt, caffemodel)

    @staticmethod
    def _network_input_width(img_width, img_height, in_height):
        """Return the blob width preserving the image aspect ratio.

        The width is rounded down to a multiple of 8 and never below 8. The
        previous expression ``((ar * h) * 8) // 8`` was effectively a plain
        floor, so the apparent multiple-of-8 intent was lost, and very narrow
        images could yield a width of 0. Presumably the multiple of 8 matches
        the network's downsampling stride — NOTE(review): confirm.
        """
        # Integer arithmetic avoids float rounding surprises.
        width = (in_height * img_width // img_height) // 8 * 8
        return max(width, 8)

    def predict(self, img):
        """Locate hand keypoints in an image.

        Args:
            img: path to an image file, or an already-decoded image array in
                the same channel order ``cv2.imread`` produces (BGR).

        Returns:
            A list of ``numPoints`` entries, each ``(x, y)`` (ints, in input
            image pixels) or ``None`` if the channel's peak confidence is not
            above ``self.threshold``.

        Raises:
            FileNotFoundError: if ``img`` is a path that cannot be read.
        """
        if isinstance(img, np.ndarray):
            img_cv2 = img
        else:
            img_cv2 = cv2.imread(img)
            # imread signals failure by returning None rather than raising.
            if img_cv2 is None:
                raise FileNotFoundError("Could not read image: {}".format(img))

        img_height, img_width = img_cv2.shape[:2]
        in_width = self._network_input_width(img_width, img_height, self.inHeight)

        # Scale pixel values to [0, 1]; no mean subtraction, no channel swap, no crop.
        inp_blob = cv2.dnn.blobFromImage(img_cv2, 1.0 / 255, (in_width, self.inHeight),
                                         (0, 0, 0), swapRB=False, crop=False)
        self.hand_net.setInput(inp_blob)
        output = self.hand_net.forward()

        points = []
        for idx in range(self.numPoints):
            # Per-channel confidence map, upsampled back to the input image size.
            prob_map = cv2.resize(output[0, idx, :, :], (img_width, img_height))

            # Global maximum of the confidence map.
            _, prob, _, point = cv2.minMaxLoc(prob_map)

            if prob > self.threshold:
                points.append((int(point[0]), int(point[1])))
            else:
                points.append(None)

        return points

if __name__ == "__main__":
    model_path = 'model/'
    image = 'images/hand_pose_2.jpg'
    start = time.time()

    model = HandPoseModel(model_path)
    
    print("[INFO] Model loads time: ", time.time() - start)

    points = model.predict(image)
    
    img = cv2.imread(image)
    for idx in range(len(points)-1):
        x, y = points[idx]
        cv2.circle(img, (int(x), int(y)), 5, (255,0,0), -1)
        cv2.putText(img, "{}".format(idx), points[idx], cv2.FONT_HERSHEY_SIMPLEX,
                            0.5, (0, 0, 255), 1, lineType=cv2.LINE_AA)

    cv2.imshow('Image', img)
    
    key = cv2.waitKey(0)
    if key == 27:
        cv2.destroyAllWindows()